{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6VFCTA5Z7sB6"
      },
      "source": [
        "# Training on WikiAnn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HV3IYEdoHnNF"
      },
      "source": [
        "In this notebook, we use [MAD-X 2.0](https://arxiv.org/pdf/2012.15562.pdf) with a stacked language and task adapter setup to perform zero-shot cross-lingual transfer for named entity recognition (NER).\n",
        "We use an NER adapter from [AdapterHub.ml](https://adapterhub.ml/explore) pre-trained on the **English** portion of the [WikiAnn](https://www.aclweb.org/anthology/P17-1178.pdf) dataset and transfer it to **Guarani** with a pre-trained language adapter.\n",
        "This notebook is similar to the `run_ner.py` example script in `examples/pytorch/token-classification/`.\n",
        "\n",
        "First, let's install `adapters` and the other required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T12:18:02.087873100Z",
          "start_time": "2023-08-17T12:17:54.490601100Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3AJLS4197g-i",
        "outputId": "58e147cd-538d-498c-ae67-7b6f7ce3e21a"
      },
      "outputs": [],
      "source": [
        "!pip install -Uq adapters\n",
        "!pip install -q datasets\n",
        "!pip install -q seqeval\n",
        "!pip install -Uq accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "knPViYpFJzup"
      },
      "source": [
        "Next, we initialize the tokenizer and the model with the correct labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:54:31.730170400Z",
          "start_time": "2023-08-17T10:54:28.844625500Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vgQ5DC_48_r_",
        "outputId": "bb2d049b-6746-43e9-bd32-3a066536dac1"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig\n",
        "from adapters import init\n",
        "\n",
        "# The labels for the NER task and the dictionaries that map them to ids\n",
        "# and back\n",
        "labels = [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\"]\n",
        "id_2_label = {id_: label for id_, label in enumerate(labels)}\n",
        "label_2_id = {label: id_ for id_, label in enumerate(labels)}\n",
        "\n",
        "model_name = \"bert-base-multilingual-cased\"\n",
        "config = AutoConfig.from_pretrained(model_name, num_labels=len(labels), label2id=label_2_id, id2label=id_2_label)\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "model = AutoModelForTokenClassification.from_pretrained(model_name, config=config)\n",
        "\n",
        "# Enable adapter support\n",
        "init(model)\n",
        "print(model.get_labels())\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LYRG9o0HKv7_"
      },
      "source": [
        "Now, we load the task and the language adapter. For both adapters, we drop the adapter in the last layer following MAD-X 2.0. We then set both adapters as active adapters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:55:18.106652Z",
          "start_time": "2023-08-17T10:55:07.121375100Z"
        },
        "id": "gZj6Jt-P7Ysd"
      },
      "outputs": [],
      "source": [
        "from adapters import AdapterConfig\n",
        "from adapters.composition import Stack\n",
        "\n",
        "target_language = \"gn\"  # Choose any language for which a bert-base-multilingual-cased language adapter is available\n",
        "source_language = \"en\"  # Supported source languages: \"en\", \"ja\", \"zh\", and \"ar\"\n",
        "\n",
        "adapter_config = AdapterConfig.load(\n",
        "    None,\n",
        "    leave_out=[11]\n",
        ")\n",
        "\n",
        "model.load_adapter(\n",
        "    \"wikiann/\" + source_language + \"@ukp\",\n",
        "    config=adapter_config,\n",
        "    load_as=\"wikiann\",\n",
        ")\n",
        "\n",
        "lang_adapter_name = model.load_adapter(\n",
        "    target_language + \"/wiki@ukp\",\n",
        "    load_as=target_language,\n",
        "    leave_out=[11],\n",
        ")\n",
        "# Set the adapters to be used in every forward pass\n",
        "model.set_active_adapters(Stack(lang_adapter_name, \"wikiann\"))"
      ]
    },
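    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `leave_out=[11]` setting concrete, the following plain-Python sketch lists which layers keep an adapter. It assumes that layer ids are zero-based and that `bert-base-multilingual-cased` has 12 Transformer layers, so id 11 is the last layer. The variable names are illustrative and are not part of the `adapters` API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sketch only: no model or adapter is touched here.\n",
        "num_layers = 12   # assumed number of Transformer layers in bert-base-multilingual-cased\n",
        "leave_out = [11]  # zero-based id of the last layer, as passed to the adapter config above\n",
        "\n",
        "# Layers that still receive an adapter module\n",
        "layers_with_adapter = [i for i in range(num_layers) if i not in leave_out]\n",
        "print(layers_with_adapter)"
      ]
    },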
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UyOE0neNKDmr"
      },
      "source": [
        "Next, we download the dataset and initialize the training arguments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:56:49.120227700Z",
          "start_time": "2023-08-17T10:55:35.178236700Z"
        },
        "id": "o-SUUa367TBr"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "from transformers import TrainingArguments\n",
        "\n",
        "datasets = load_dataset('wikiann', target_language)\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    per_device_eval_batch_size=64,\n",
        "    do_predict=True,\n",
        "    output_dir=\"ner_models/madx/\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DQ4-v6BuKJzl"
      },
      "source": [
        "This function is taken from the example script `run_ner.py`. It tokenizes the input words with the model's tokenizer and aligns the word-level labels with the resulting sub-word tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:57:18.725260400Z",
          "start_time": "2023-08-17T10:57:18.709640900Z"
        },
        "id": "bbfMD2EM8OAs"
      },
      "outputs": [],
      "source": [
        "# This function is adapted from the Hugging Face Transformers run_ner.py example script.\n",
        "# Tokenize all texts and align the labels with them.\n",
        "def tokenize_and_align_labels(examples):\n",
        "    text_column_name = \"tokens\"\n",
        "    label_column_name = \"ner_tags\"\n",
        "    tokenized_inputs = tokenizer(\n",
        "        examples[text_column_name],\n",
        "        padding=False,\n",
        "        truncation=True,\n",
        "        # We use this argument because the texts in our dataset are lists of words (with a label for each word).\n",
        "        is_split_into_words=True,\n",
        "    )\n",
        "    labels = []\n",
        "    for i, label in enumerate(examples[label_column_name]):\n",
        "        word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
        "        previous_word_idx = None\n",
        "        label_ids = []\n",
        "        for word_idx in word_ids:\n",
        "            # Special tokens have a word id that is None. We set the label to -100 so they are automatically\n",
        "            # ignored in the loss function.\n",
        "            if word_idx is None:\n",
        "                label_ids.append(-100)\n",
        "            # We set the label for the first token of each word.\n",
        "            elif word_idx != previous_word_idx:\n",
        "                label_ids.append(label[word_idx])\n",
        "            # For the other tokens in a word, we set the label to -100 so that only the first\n",
        "            # sub-token of each word contributes to the loss and the metrics.\n",
        "            else:\n",
        "                label_ids.append(-100)\n",
        "            previous_word_idx = word_idx\n",
        "\n",
        "        labels.append(label_ids)\n",
        "    tokenized_inputs[\"labels\"] = labels\n",
        "    return tokenized_inputs"
      ]
    },
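    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small worked example of the alignment rule used above, the next cell applies the same three cases to hand-written values. The word ids and labels are made up for illustration (they do not come from the tokenizer or from WikiAnn), and the names `word_labels` and `aligned` are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy illustration with hand-written values (no tokenizer involved).\n",
        "word_labels = [1, 2, 0]  # B-PER, I-PER, O for three words\n",
        "# Assumed word ids for: special token, word 0, word 1 (two sub-tokens), word 2, special token\n",
        "word_ids = [None, 0, 1, 1, 2, None]\n",
        "\n",
        "aligned = []\n",
        "previous_word_idx = None\n",
        "for word_idx in word_ids:\n",
        "    if word_idx is None:\n",
        "        aligned.append(-100)  # special token: ignored\n",
        "    elif word_idx != previous_word_idx:\n",
        "        aligned.append(word_labels[word_idx])  # first sub-token of a word\n",
        "    else:\n",
        "        aligned.append(-100)  # later sub-token of the same word: ignored\n",
        "    previous_word_idx = word_idx\n",
        "\n",
        "print(aligned)"
      ]
    },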
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lBiIMcWOKjih"
      },
      "source": [
        "We apply the function defined above to the test split to prepare it for prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:57:33.497198700Z",
          "start_time": "2023-08-17T10:57:33.406508100Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "d7bc633f2ede464aacc1e72fcfa12f7a",
            "9bbc7b3250bb413d87fd92481fbc124e",
            "9cd70a43ebb94a70acbba42eec2da49c",
            "b3d0bdd16c204efb821770b14e2cad90",
            "22da358e24834033a9f58897fcb5c8bc",
            "738643294a6a4c82a3b1302ac5c3eeae",
            "f3c854026aaa4a0489a2077af93ebb15",
            "df5aafeed72b4be9996671e9541ac56a",
            "4dc17f1d57684fc398a2834d6b39bf02",
            "b75ef14b0a7d4925ad06f22172a3b290",
            "5311b979ae3b4fba870e2de76d4e17cd"
          ]
        },
        "id": "xrnTO-hV8pRc",
        "outputId": "f6bd5820-c176-475e-fa10-8f449009587d"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d7bc633f2ede464aacc1e72fcfa12f7a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/100 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from transformers import DataCollatorForTokenClassification\n",
        "test_dataset = datasets[\"test\"]\n",
        "test_dataset = test_dataset.map(\n",
        "    tokenize_and_align_labels,\n",
        "    batched=True,\n",
        ")\n",
        "\n",
        "data_collator = DataCollatorForTokenClassification(tokenizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rR8_a-MILFLZ"
      },
      "source": [
        "We use the `AdapterTrainer` class, a subclass of Hugging Face's `Trainer`, to evaluate zero-shot transfer on the WikiAnn test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:57:37.702784100Z",
          "start_time": "2023-08-17T10:57:36.269889400Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84,
          "referenced_widgets": [
            "874cf4195a8e4c90ba6fcc7f36cb9483",
            "044148eb2ee6451c9f1fbb0137f59ac9",
            "0c6a31da49c5480d9a949da992bd95cd",
            "545bcd6670154ef18cb9603877cb707f",
            "06507f5e6d2c4319847750520af267e2",
            "cd1abe7ae13e418c8e3e04c9b26dfe9d",
            "3dad54731ab948ddb4b80a0d43a8c0a6",
            "bf106e11f0a749b182fc3f0e2df5429d",
            "4e02dd0f353047be90dbaf0fa2ecf0c3",
            "8872e1de590e49e7a648435eea7b4fe0",
            "c1845d1d27864c1887ff7b4882d06a6e"
          ]
        },
        "id": "H07g_-if9xb0",
        "outputId": "6c824656-0017-4eae-e630-506b71079647"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<ipython-input-7-17d096b1ef3d>:7: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
            "  metric = load_metric(\"seqeval\")\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "874cf4195a8e4c90ba6fcc7f36cb9483",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
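      ],
      "source": [
        "from adapters import AdapterTrainer\n",
        "import numpy as np\n",
        "\n",
        "# Metrics: datasets.load_metric is deprecated, so prefer the evaluate library when it is installed\n",
        "try:\n",
        "    import evaluate\n",
        "    metric = evaluate.load(\"seqeval\")\n",
        "except ImportError:\n",
        "    from datasets import load_metric\n",
        "    metric = load_metric(\"seqeval\")\n",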
        "\n",
        "def compute_metrics(p):\n",
        "    predictions, labels = p\n",
        "    predictions = np.argmax(predictions, axis=2)\n",
        "    label_list = id_2_label\n",
        "\n",
        "    # Remove ignored index (special tokens)\n",
        "    true_predictions = [\n",
        "        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
        "        for prediction, label in zip(predictions, labels)\n",
        "    ]\n",
        "    true_labels = [\n",
        "        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
        "        for prediction, label in zip(predictions, labels)\n",
        "    ]\n",
        "\n",
        "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
        "    return {\n",
        "        \"precision\": results[\"overall_precision\"],\n",
        "        \"recall\": results[\"overall_recall\"],\n",
        "        \"f1\": results[\"overall_f1\"],\n",
        "        \"accuracy\": results[\"overall_accuracy\"],\n",
        "    }\n",
        "\n",
        "trainer = AdapterTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=None,\n",
        "    eval_dataset=test_dataset,\n",
        "    tokenizer=tokenizer,\n",
        "    data_collator=data_collator,\n",
        "    compute_metrics=compute_metrics,\n",
        ")"
      ]
    },
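    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the filtering step in `compute_metrics` does, the next cell runs the same argmax and `-100` filtering on a tiny hand-made batch. All values and the `toy_` names are invented for illustration; they are not model outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Toy values: one sequence of four tokens and three classes (ids 0..2).\n",
        "toy_id_2_label = {0: 'O', 1: 'B-PER', 2: 'I-PER'}\n",
        "toy_logits = np.array([[[0.1, 0.8, 0.1],\n",
        "                        [0.2, 0.1, 0.7],\n",
        "                        [0.9, 0.05, 0.05],\n",
        "                        [0.3, 0.4, 0.3]]])\n",
        "toy_labels = np.array([[1, 2, -100, 0]])  # the third token is ignored\n",
        "\n",
        "# Highest-scoring class per token, then drop positions labelled -100\n",
        "toy_predictions = np.argmax(toy_logits, axis=2)\n",
        "kept_predictions = [\n",
        "    [toy_id_2_label[int(p)] for p, l in zip(pred, lab) if l != -100]\n",
        "    for pred, lab in zip(toy_predictions, toy_labels)\n",
        "]\n",
        "kept_labels = [\n",
        "    [toy_id_2_label[int(l)] for p, l in zip(pred, lab) if l != -100]\n",
        "    for pred, lab in zip(toy_predictions, toy_labels)\n",
        "]\n",
        "print(kept_predictions)\n",
        "print(kept_labels)"
      ]
    },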
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F5lKBoZ9LPnb"
      },
      "source": [
        "Finally, we predict the labels for the test set and evaluate the predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:57:57.481411100Z",
          "start_time": "2023-08-17T10:57:47.722377900Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 193
        },
        "id": "HAxswwN3hmyw",
        "outputId": "84383f0c-8b7a-4c8d-b89b-4d6a4d240fbd"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2' max='2' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2/2 00:00]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'eval_loss': 1.0064687728881836,\n",
              " 'eval_precision': 0.437125748502994,\n",
              " 'eval_recall': 0.6952380952380952,\n",
              " 'eval_f1': 0.5367647058823529,\n",
              " 'eval_accuracy': 0.784037558685446,\n",
              " 'eval_runtime': 5.3398,\n",
              " 'eval_samples_per_second': 18.727,\n",
              " 'eval_steps_per_second': 0.375}"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    }
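    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check, the reported F1 score should be the harmonic mean of precision and recall, $F_1 = 2PR / (P + R)$. The cell below recomputes it from the precision and recall values in the stored evaluation output above; if you rerun the notebook and get different numbers, substitute your own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Values copied from the stored evaluation output above\n",
        "precision = 0.437125748502994\n",
        "recall = 0.6952380952380952\n",
        "\n",
        "# Harmonic mean of precision and recall\n",
        "f1 = 2 * precision * recall / (precision + recall)\n",
        "print(round(f1, 4))"
      ]
    }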
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
